package com.xzcs.filter;

import com.xzcs.util.common.URLUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.UUID;

/**
 * csrf
 * @author chang
 * @createDate 2016-7-20
 */
public class CsrfFilter implements Filter {
	protected final Logger log = LoggerFactory.getLogger(CsrfFilter.class);
	private String splitChar;//分割符
	private String excludeUrls;//放过的url
	private String errorPage;//错误页面
	FilterConfig filterConfig = null;
	public void init(FilterConfig filterConfig) throws ServletException {
		this.splitChar=filterConfig.getInitParameter("splitChar");
		this.excludeUrls=filterConfig.getInitParameter("excludeUrls");
		this.errorPage=filterConfig.getInitParameter("errorPage");
		this.filterConfig = filterConfig;
	}

	public void destroy() {
		this.filterConfig = null;
	}

	public void doFilter(ServletRequest request, ServletResponse response,
			FilterChain chain) throws IOException, ServletException {
		if(isExcludeUrl(request)){
			chain.doFilter(request, response);
		}else{
			HttpServletRequest req = (HttpServletRequest) request;
			HttpSession session = req.getSession();
			String sToken = (String)session.getAttribute("_csrf_");
			if(null == sToken){
				// 产生新的 token 放入 session 中
				sToken = generateToken();
				session.setAttribute("_csrf_", sToken);
				chain.doFilter(request, response);
			} else{
				// 从 HTTP 头中取得 csrftoken
				String xhrToken = req.getHeader("_csrf_");
				// 从请求参数中取得 csrftoken
				String pToken = req.getParameter("_csrf_");
				if(null != sToken && null != xhrToken && sToken.equals(xhrToken)){
					chain.doFilter(request, response);
				}else if(null != sToken && null != pToken && sToken.equals(pToken)){
					chain.doFilter(request, response);
				}else{
					request.getRequestDispatcher(errorPage).forward(request,response);
			    }
			}
		}
	}

	private boolean isExcludeUrl(ServletRequest request){
		boolean exclude=false;
		if(StringUtils.isNotBlank(excludeUrls)){
			String[]excludeUrl=excludeUrls.split(splitChar);
			if(excludeUrl!=null&&excludeUrl.length>0){
				for(String url:excludeUrl){
					if(URLUtils.getURI((HttpServletRequest)request).contains(url)){
						exclude=true;
					}
				}
			}
		}
		return exclude;
	}
	
	private String generateToken() {
	    return UUID.randomUUID().toString();
	}
}
